Prompt-Based Tuning of Transformer Models for Multi-Center Medical Image Segmentation of Head and Neck Cancer

Medical image segmentation is a vital healthcare endeavor requiring precise and efficient models for appropriate diagnosis and treatment. Vision transformer (ViT)-based segmentation models have shown great performance in accomplishing this task. However, to build a powerful backbone, the self-attention block of ViT requires large-scale pre-training data. The present method of modifying pre-trained models entails updating all or some of the backbone parameters. This paper proposes a novel fine-tuning strategy for adapting a pretrained transformer-based segmentation model on data from a new medical center. This method introduces a small number of learnable parameters, termed prompts, into the input space (less than 1% of model parameters) while keeping the rest of the model parameters frozen. Extensive studies employing data from new unseen medical centers show that the prompt-based fine-tuning of medical segmentation models provides excellent performance regarding the new-center data with a negligible drop regarding the old centers. Additionally, our strategy delivers great accuracy with minimum re-training on new-center data, significantly decreasing the computational and time costs of fine-tuning pre-trained models. Our source code will be made publicly available.


Introduction
Recently, several novel segmentation models have been proposed to assist in medical image analysis and understanding, leading to faster and more accurate treatment planning [1][2][3]. Many of these models are transformer-based and demonstrate excellent performance on several medical datasets. Transformers are a class of neural network architectures distinguished chiefly by their heavy use of the attention mechanism [4]. In particular, vision transformers (ViTs) [5] have demonstrated their ability in 3D medical image segmentation [6,7]. However, ViTs intrinsically lack image-specific inductive bias and scaling behavior; this lack is mitigated by using large datasets and large model capacity.
On the other hand, medical datasets are limited in size due to time-consuming and expensive expert annotations, which hinders the use of powerful transformer models at their full capacity. A common approach to handling the limited data size in the medical domain is transfer learning [8]. Multiple studies have exploited pretrained networks for different downstream tasks such as classification [9], segmentation [10], and progression [11]. This technique aims to reuse the weights of ViTs already trained on different but related tasks. More specifically, models are first pretrained on a different large dataset; the pretrained weights act as informed initializations of the model [12][13][14]. The pretrained model is then fine-tuned on the target dataset, yielding faster training and a more generalizable model. However, the limited size of medical datasets is not the only challenge; medical datasets are sourced from different medical centers that use different machines and acquisition protocols, leading to further heterogeneity in the acquired data [15,16]. As a result, a model trained on data obtained from specific medical centers might fail to perform well on data obtained from a new medical center (see Figure 1, Scenario 1). Conventionally, transfer learning can be used to adapt the pretrained model to the new medical center's data. One such adaptation strategy is partial/full fine-tuning, in which some/all of the parameters of the pretrained model are fine-tuned on the new center's data (see Figure 1, Scenario 2). However, directly fine-tuning a pretrained transformer model on a new center's data can lead to overfitting (as datasets from any new center are mostly small) and catastrophic forgetting (loss of knowledge learned from the previous centers) [17,18]. Hence, this strategy requires storing and deploying a separate copy of the backbone parameters for every newly acquired medical center's data.
This strategy is costly and infeasible if the end solution is regularly deployed at new medical centers or if the acquisition protocol and/or machines in an existing center change. This infeasibility is more prominent in transformer-based models, as they are significantly larger than their convolutional neural network (CNN) counterparts. Another possibility is to re-train the model on samples from both old- and new-center data and re-deploy it for inference (see Figure 1, Scenario 3). This scenario is computationally expensive and infeasible owing to the same pitfalls as Scenario 2.
In this work, inspired by [19][20][21], we propose a prompt-based fine-tuning method for ViTs on new medical centers' data. It is important to note that previous studies have mainly focused on large language models [19,20] and natural images [21]. In contrast, our research is centered on using prompt-based fine-tuning to tackle medical image segmentation tasks. More specifically, we look at multi-class segmentation of cancer lesions with multi-center data. Instead of altering or fine-tuning the pretrained transformer, we introduce center-specific learnable token parameters, called prompts, in the input space of the segmentation model. Only the prompts and the output convolutional layer are learnable during the fine-tuning of the model on the new center's data. The rest of the pretrained transformer model is frozen. Current deployment scenarios as well as our proposed approach (Scenario 4) are depicted in Figure 1.
We show that this method can achieve high accuracy on new centers' data with a negligible loss of accuracy on the old centers, in contrast to full or partial fine-tuning techniques, where the accuracy on the old-center data is compromised. The main contributions of this work are as follows:
• We propose a new prompt-based fine-tuning technique for transformer-based medical image segmentation models that reduces the fine-tuning time and the number of learnable parameters (less than 1% of the model parameters) to be stored for the new medical center.
• The proposed method achieves equivalent accuracy on new-center data compared to the full fine-tuning technique while mostly preserving the accuracy on the old-center data, which full fine-tuning compromises.
• We showcase the efficacy of the proposed method on multi-class segmentation of head and neck cancer tumors using multi-channel computed tomography (CT) and positron emission tomography (PET) scans of patients obtained from multi-center (seven centers) sources.

Figure 1. Deployment scenarios for new-center data. In Scenario 1, the new-center data is directly inferred through the deployed model trained on old-center data (no fine-tuning). In Scenario 2, the model is fully or partially fine-tuned on the new-center data before being deployed for inference. In Scenario 3, the model is retrained using both old- and new-center data before deployment. Our proposed method (Scenario 4) uses data solely from the new center to fine-tune only the prompts while keeping the trained model frozen, and then deploys it.

Methodology
Due to differences in how imaging is done, what equipment is used, and who the patients are, the quality and distribution of the data collected by different medical centers might be very different. This heterogeneity represents a barrier to developing precise and robust models that generalize well to new medical center data. In this section, we describe a novel tuning technique, called prompt-based tuning, for adapting transformer-based medical image segmentation models that overcomes the pitfalls of conventional fine-tuning techniques. Prompt-based fine-tuning injects a small number of learnable parameters into the transformer's input space and keeps the backbone of the trained model frozen during the downstream training stage. The overall framework is presented in Figure 2. We demonstrate two variants of prompt-based tuning, shallow and deep, and compare their performance to conventional fine-tuning methods such as partial and full fine-tuning. Below, we describe the two prompt-based tuning methods and highlight the differences between them.
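The general idea of learnable prompts on a frozen backbone can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the authors' implementation: a generic `nn.TransformerEncoder` stands in for the pretrained segmentation backbone, a linear layer stands in for the final output layer, and the class name `ShallowPromptEncoder` and all sizes are hypothetical.

```python
import torch
import torch.nn as nn


class ShallowPromptEncoder(nn.Module):
    """Toy model: learnable prompts are prepended to the tokens of a frozen encoder.

    Assumption: a generic transformer encoder replaces the real pretrained
    backbone, and a linear head replaces the final convolutional layer.
    """

    def __init__(self, dim=32, depth=2, heads=4, num_prompts=5, num_classes=3):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=2 * dim, batch_first=True
        )
        self.backbone = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)
        # One prompt matrix of shape (num_prompts, dim), shared across the batch.
        self.prompts = nn.Parameter(torch.zeros(1, num_prompts, dim))
        nn.init.normal_(self.prompts, std=0.02)

    def freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, tokens):
        # tokens: (batch, n, dim) patch embeddings.
        prompts = self.prompts.expand(tokens.shape[0], -1, -1)
        x = torch.cat([prompts, tokens], dim=1)   # (batch, p + n, dim)
        x = self.backbone(x)
        x = x[:, self.prompts.shape[1]:]          # drop prompt outputs, keep patches
        return self.head(x)


model = ShallowPromptEncoder()
model.freeze_backbone()

# Only the prompts and the output head remain trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

out = model(torch.randn(2, 8, 32))  # 2 samples, 8 patch tokens each
```

Passing only the parameters with `requires_grad=True` to the optimizer means that a new center needs to store just the prompts and the output layer, while the backbone weights are shared.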

Shallow Prompt Tuning
In shallow prompt fine-tuning, a set of p continuous prompts of dimension d is introduced in the input space after the embedding layer. These prompts are concatenated with the token embeddings of the volumetric patches of an input image x ∈ R^(H×W×D×C), where H, W, D, and C are the height, width, depth, and channels of the 3D image, respectively. K × K × K represents the dimensions of each patch, and n = HWD/K^3 is the number of patches extracted. The embedding layer projects these patches to a dimension d. The class token is dropped from the ViT [5], as the experiments are for a segmentation task. The resulting concatenated prompts and embeddings are fed to a transformer encoder consisting of L layers, following the same pipeline as the original ViT [5], with normalization, multi-head self-attention (MSA), and a multi-layer perceptron. The decoder only uses image patch embeddings as inputs, and prompt embeddings are discarded. In the shallow prompt-based formulation, P denotes the p × d prompt matrix and ConvTrans3D refers to 3D transpose convolution.
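One way to write the shallow formulation from the definitions above is sketched below, in the style of visual prompt tuning [21]. The symbols E_i (patch embeddings after layer i), Z_i (prompt outputs after layer i), T_i (the i-th frozen encoder layer), Embed, and ŷ are assumed notation rather than the original's, and the decoder is abbreviated to a single ConvTrans3D applied to the patch embeddings, with skip connections omitted.

```latex
\begin{aligned}
E_0 &= \mathrm{Embed}(x), \qquad E_0 \in \mathbb{R}^{n \times d}, \quad P \in \mathbb{R}^{p \times d},\\
[Z_1, E_1] &= \mathcal{T}_1\big([P, E_0]\big),\\
[Z_i, E_i] &= \mathcal{T}_i\big([Z_{i-1}, E_{i-1}]\big), \qquad i = 2, \dots, L,\\
\hat{y} &= \mathrm{ConvTrans3D}(E_L).
\end{aligned}
```

Here [·, ·] denotes concatenation along the token dimension, so each layer processes p + n tokens. Only P and the output layer are updated; the layers T_i and the embedding stay frozen, and the prompt outputs Z_L never reach the decoder.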

Deep Prompt Tuning
In deep prompt fine-tuning, the prompts can be introduced at the input space of each transformer layer or of a subset of layers. In our implementation, we add the deep prompts after each skip connection layer.

Figure 2. Overview of the proposed method. Learnable prompts are appended to the embedded tokens in the input space and passed through the transformer encoder but not the decoder during the fine-tuning. In deep prompt-based fine-tuning, the learnable prompts are replaced by new prompts after each transformer layer.
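The deep variant can be sketched as follows. All notation here is assumed for illustration: E_i are the patch embeddings after layer i, Z_i the prompt outputs after layer i, T_i the i-th frozen encoder layer, P_i a fresh learnable p × d prompt matrix inserted before layer i + 1, and S the set of layer indices after which new prompts are inserted (for example, the skip connection layers).

```latex
\begin{aligned}
E_0 &= \mathrm{Embed}(x),\\
[Z_i, E_i] &= \mathcal{T}_i\big([\tilde{Z}_{i-1}, E_{i-1}]\big), \qquad i = 1, \dots, L,\\
\tilde{Z}_{i-1} &=
\begin{cases}
P_{i-1}, & i - 1 \in \{0\} \cup S,\\
Z_{i-1}, & \text{otherwise},
\end{cases}\\
\hat{y} &= \mathrm{ConvTrans3D}(E_L).
\end{aligned}
```

Setting S to the empty set recovers the shallow variant with a single prompt matrix P_0, while S = {1, ..., L − 1} replaces the prompts after every layer. The number of learnable prompt parameters grows as (|S| + 1) · p · d.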

Experiments
We use the state-of-the-art transformer-based segmentation models, UNETR [6] and Swin-UNETR [22]. In addition, we compare the two variants of the proposed method to partial and full fine-tuning, two prevalent transfer learning protocols used in medical imaging.

Dataset
The dataset used in this work is multi-center, multi-class, and multi-modal. This dataset comprises head and neck cancer patient scans collected from seven centers. The data consist of CT and PET scans, as well as electronic health records (EHR) of each patient. The PET volume is registered with the CT volume to a common origin, although they each have varying sizes and resolutions. The CT sizes range from (128, 128, 67) to (512, 512, 736), while the PET sizes range from (128, 128, 66) to (256, 256, 543) voxels. The CT resolutions range from (0.488, 0.488, 1.00) to (2.73, 2.73, 2.80), while the PET resolutions range from (2.73, 2.73, 2.00) to (5.47, 5.47, 5.00) mm in the x, y, and z directions. Some scans are of the head and neck regions, while others contain the full body of the patients.
As shown in Figure 3, the PET/CT scans are in the NIfTI format. They have been resampled to 1 × 1 × 1 mm^3 isotropic resolution and cropped to a dimension of 176 × 176 × 176 around the primary tumor and lymph nodes. The CT Hounsfield unit (HU) values are clipped to a range of −200 to 200, while the PET values are clipped to a maximum standardized uptake value (SUV) of 5. The dataset contains segmentation masks for each patient, including the ground truth of primary gross tumor volumes (GTVp), nodal gross tumor volumes (GTVn), and other clinical information. The annotations were made by medical professionals at the respective centers and are provided with the dataset. The dataset is publicly available on the MICCAI 2022 HEad and neCK TumOR (HECKTOR) challenge website [23]. The complete dataset consists of 524 samples. The detailed distribution of the dataset across different centers is listed in Table 1, along with the type of scanner used to acquire the scans.

Experimental Setup
The dataset for each of the seven centers is first split into train and test sets with a ratio of 70:30, respectively, for a fair comparison. In all experiments, the model is first pre-trained using the six centers' training data and then fine-tuned on the seventh center's training data. We evaluate the performance of the model on (1) the seventh center's test set (new center) and (2) the six centers' test sets (old centers). We compare both metrics for the following fine-tuning techniques, as shown in Figure 4.
• No fine-tuning: The pre-trained model is used directly to infer the test samples without any fine-tuning.
• Partial fine-tuning: The last decoder block of the pre-trained model is fine-tuned using the seventh center's training set.
• Full fine-tuning: The entire pre-trained model is fine-tuned using the seventh center's training set.
• Shallow prompt fine-tuning: This is a variant of prompt-based fine-tuning in which the prompts are introduced only in the input space. Only the prompts and the final convolutional layer are fine-tuned using the seventh center's training set, while the rest of the model is frozen.
• Deep prompt fine-tuning: This technique is similar to shallow prompt fine-tuning, except that prompts are introduced at each level of the transformer layers. Thus, at each level, there are new trainable prompts to refine. The prompts and the final convolutional layer are fine-tuned using the seventh center's training set.
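The data arrangement described above (a per-center 70:30 split, pre-training on all but one center, then fine-tuning on the held-out center) can be sketched as follows. The function names, the center names, and the sample counts are hypothetical and chosen only for illustration; the shuffling seed is likewise an assumption.

```python
import random


def split_center(sample_ids, train_fraction=0.7, seed=0):
    """Shuffle one center's sample IDs and split them into (train, test) lists."""
    ids = list(sample_ids)
    random.Random(seed).shuffle(ids)
    cut = round(train_fraction * len(ids))
    return ids[:cut], ids[cut:]


def leave_one_center_out(centers, new_center, train_fraction=0.7, seed=0):
    """Build the pre-training, fine-tuning, and two evaluation sets for one new center."""
    splits = {
        name: split_center(ids, train_fraction, seed) for name, ids in centers.items()
    }
    old = [name for name in centers if name != new_center]
    return {
        "pretrain_train": [s for name in old for s in splits[name][0]],
        "old_test": [s for name in old for s in splits[name][1]],
        "finetune_train": splits[new_center][0],
        "new_test": splits[new_center][1],
    }


# Hypothetical centers and sample counts, for illustration only.
centers = {
    "center_A": [f"A_{i}" for i in range(10)],
    "center_B": [f"B_{i}" for i in range(20)],
    "center_C": [f"C_{i}" for i in range(30)],
}
plan = leave_one_center_out(centers, new_center="center_C")
```

Splitting each center before pooling keeps every old center represented in the old-center test set, and repeating the call with each center as `new_center` yields one experiment per center.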

Implementation Details
We implement all our models using the PyTorch framework and train them on a single NVIDIA Tesla A6000 GPU. The details of the experimental settings for all fine-tuning techniques are listed in Appendix A Table A1.
All images are aligned to the same 3D orientation (anterior-posterior, right-left, and inferior-superior) during training and testing. The CT/PET scans are concatenated to form a 2-channel input, with their intensity values independently normalized based on their respective means and standard deviations. The training augmentations applied to the CT/PET scans include extracting four random crops of size 96 × 96 × 96, with each having an equal probability of being centered around the primary tumor or lymph node voxels and the background voxels. The images are randomly flipped in the x, y, and z directions, with a probability of 0.2, and are further rotated by 90 degrees in the x and y directions up to 3 times, with a probability of 0.2. These augmentations aim to create more diverse and representative training data, which can help to improve the performance and generalization of deep learning models for medical image analysis tasks. All pre-processing and augmentation details of the data are listed in Appendix A Table A2.

Results
Table 2 presents the results of fine-tuning the pre-trained UNETR and Swin-UNETR on the old and new medical center datasets. We conduct our evaluations using five-fold cross-validation with a total of 290 experiments. The results of all the folds for all the centers can be found in the Supplementary material. We use the Dice score [24] to evaluate the performance of segmentation in our experiments. We can observe that:

1. All the fine-tuning techniques yield better performance for the new centers than direct inference on the pre-trained models.

2. Shallow prompt-based fine-tuning achieves a higher or comparable Dice score on the new-center data, with nearly the same number of learnable parameters as partial fine-tuning (see Table 3). Moreover, shallow prompts outperform partial and full fine-tuning techniques on the old-center data for all seven centers.

3. Deep prompt-based fine-tuning achieves the same Dice score as full fine-tuning on the new-center data but with significantly fewer learnable parameters. In addition, deep prompt-based fine-tuning outperforms full fine-tuning on old-center data for all seven centers. Thus, even if the storage of model weights is not a concern, prompt-based fine-tuning is still a promising approach for fine-tuning models, as it retains more knowledge related to old centers.

4. The prompt-based fine-tuning of Swin-UNETR exhibits a similar pattern to that of UNETR. However, the loss in performance on old-center data for the conventional fine-tuning methods is less prominent for some centers compared to that of UNETR. This can be explained by the inductive biases in Swin-UNETR, which employs MSA within local shifted windows and merges patch embeddings at deeper layers. Swin-UNETR requires further optimization with regard to prompt position to further improve its performance.
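For reference, the Dice score used in the comparisons above measures the overlap between a predicted and a ground-truth mask as 2|A ∩ B| / (|A| + |B|). The sketch below computes it per class on flattened label masks. The label encoding (0 for background, 1 for GTVp, 2 for GTVn), the toy masks, and the convention of returning 1.0 when a class is absent from both masks are assumptions made for illustration.

```python
def dice_score(pred, target, label):
    """Dice overlap for one class label between two equally sized flat masks."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of voxels")
    pred_hits = [p == label for p in pred]
    target_hits = [t == label for t in target]
    intersection = sum(p and t for p, t in zip(pred_hits, target_hits))
    total = sum(pred_hits) + sum(target_hits)
    if total == 0:
        # Assumption: a class absent from both masks counts as perfect agreement.
        return 1.0
    return 2.0 * intersection / total


# Toy flattened masks; assumed encoding: 0 = background, 1 = GTVp, 2 = GTVn.
pred = [0, 1, 1, 2, 2, 0, 0, 1]
target = [0, 1, 0, 2, 2, 2, 0, 1]

dice_gtvp = dice_score(pred, target, label=1)  # 2 * 2 / (3 + 2)
dice_gtvn = dice_score(pred, target, label=2)  # 2 * 2 / (2 + 3)
mean_dice = (dice_gtvp + dice_gtvn) / 2
```

Averaging over the foreground classes, as in `mean_dice`, is one common aggregation choice; it is not necessarily the one used to produce the reported tables.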

Discussion
This work introduces a new method for fine-tuning transformer-based medical segmentation models on new-center data. Our method is more efficient than conventional approaches, requiring fewer parameters and a lower computational cost while achieving the same or better performance on new-center data (Table A3). We show superior performance for prompt-based fine-tuning compared to other techniques, achieving a statistically significant increase in the Dice score for old centers. We note the difference in performance between CHUP and CHUS, which have a similar number of samples but different acquisition machines and origins. CHUP exhibits a larger drop in performance on the old centers than CHUS (nearly 8% in CHUP vs. 1% in CHUS for partial and full fine-tuning). This is likely due to the larger distribution shift in CHUP compared to the rest of the centers. However, if shallow or deep prompt-based fine-tuning is used, the drop is only 2-3%. We perform a Wilcoxon signed-rank test [25] to assess whether the deep prompt-based tuning of medical segmentation models is significantly better than other fine-tuning techniques on old- and new-center data. The null hypothesis H0 states that the segmentation performance of deep prompt-based fine-tuning is statistically the same as that of the other techniques; the alternative hypothesis H1 states that the deep prompt-based technique outperforms the other methods. Table 4 presents the results of each test; it can be observed that deep prompt-based fine-tuning outperforms full and partial fine-tuning techniques on the old centers' data. Similarly, it outperforms the partial fine-tuning and shallow prompt-based techniques on the new-center data. However, the test fails to reject the null hypothesis on the new center's data for full fine-tuning.
Thus, we proceed to perform a two-tailed t-test and confirm that the performances of deep prompt-based fine-tuning and full fine-tuning on new-center data are statistically the same (p-value < 0.05).

Table 4. Wilcoxon signed-rank test on whether the performance of deep prompt-based fine-tuning of UNETR is better than that of the other methods on old centers and on new centers.
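To make the one-sided Wilcoxon signed-rank test concrete, the sketch below computes the statistic W+ and an exact p-value by enumerating all sign patterns, which is feasible only for a small number of paired observations. The function names and the per-fold Dice values are hypothetical and are not results from this study.

```python
from itertools import product


def average_ranks(values):
    """Rank values from 1 to n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def wilcoxon_greater(a, b):
    """Exact one-sided Wilcoxon signed-rank test of H1: a tends to exceed b.

    Returns (W+, p-value). Zero differences are dropped before ranking.
    """
    diffs = [x - y for x, y in zip(a, b) if x != y]
    if not diffs:
        return 0.0, 1.0
    ranks = average_ranks([abs(d) for d in diffs])
    w_plus = sum(r for r, d in zip(ranks, diffs) if d > 0)
    # Under H0 each difference is equally likely to be positive or negative,
    # so every sign pattern has probability 1 / 2**n.
    count = 0
    for signs in product((0, 1), repeat=len(ranks)):
        if sum(r for r, s in zip(ranks, signs) if s) >= w_plus:
            count += 1
    return w_plus, count / 2 ** len(ranks)


# Hypothetical per-fold Dice scores, for illustration only.
deep_prompt = [0.74, 0.71, 0.77, 0.73, 0.75]
full_finetune = [0.69, 0.70, 0.72, 0.70, 0.71]
w_plus, p_value = wilcoxon_greater(deep_prompt, full_finetune)
```

With five paired folds, the smallest attainable one-sided p-value is 1/32, reached only when every difference favors the first method. For larger samples, a library routine such as `scipy.stats.wilcoxon` with `alternative="greater"` is the more practical choice.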
In our experiments, we observed that the extra learnable prompts at deeper layers in deep prompt-based fine-tuning improve the performance compared to shallow prompt-based fine-tuning, which only inserts prompts in the input space after the patch embedding layer. We present the results of ablating different prompt positions and prompt numbers in Tables A11-A13. Our findings indicate that the specific position of the prompts does not significantly influence the model's performance when the number of prompts is fixed. However, for a fixed number of prompts distributed across various layers, incorporating prompts into the skip connection layers adversely affects the model's performance, while their exclusion leads to performance improvements, as shown in Table A12. Furthermore, the results reveal that increasing the number of prompts initially yields improvements in performance. However, there is a threshold beyond which performance degrades, which suggests that adding too many prompts in the deeper layers can over-parameterize the model and may result in overfitting on new-center data. These results serve as motivation for our choice to position the deep prompts after the skip connection layers in our design. Further studies will be conducted to quantify the effect of the number and position of the prompts.

Conclusions
We propose a prompt-based fine-tuning framework for the medical image segmentation problem. This method takes advantage of the ability of transformers to handle a variable number of tokens at the input and at deeper layers. We validate our proposed method by training transformer-based segmentation models on head and neck PET/CT scans and compare our results with conventional fine-tuning techniques. Although we were able to show the efficacy of the proposed method on medical image segmentation problems, further investigation is needed to study its scalability to other transformer-based segmentation models. In addition, investigation of prompt-based learning in different tasks, such as classification and prognosis, is needed to assess its efficacy, along with a comparison of its performance with domain generalization methods.

Data Availability Statement:
Data used in this study were obtained from the Head and Neck Tumor Segmentation and Outcome Prediction in the PET/CT Images challenge [23].

Conflicts of Interest:
The authors declare no conflict of interest. The funders had no role in the design of the study; in the collection, analyses, or interpretation of data; in the writing of the manuscript; or in the decision to publish the results.

Abbreviations
The following abbreviations are used in this manuscript: